# A better Python cache for slow function calls
**William Zeng** - March 18th, 2024

---
We wrote a file cache - it's like Python's `lru_cache`, but it stores the values in files instead of in memory. This has saved us hours of time running our LLM benchmarks, and we'd like to share it as its own Python module. Thanks to [Luke Jaggernauth](https://github.com/lukejagg) (former Sweep engineer) for building the initial version of this!

Here's the link:
https://github.com/sweepai/sweep/blob/main/docs/public/file_cache.py. To use it, simply add the `file_cache` decorator to your function. Here's an example:
```python
import time
from file_cache import file_cache

@file_cache()
def slow_function(x, y):
    time.sleep(30)
    return x + y

print(slow_function(1, 2)) # -> 3, takes 30 seconds
print(slow_function(1, 2)) # -> 3, takes 0 seconds
```

## Background
We spend a lot of time prompt engineering our agents at Sweep. Each agent takes a set of input strings, formats them as a prompt, then sends the prompt and any other information off to the LLM. We chain multiple agents to turn a GitHub issue into a pull request.
For example, to modify code we'll input the old code, any relevant context, and instructions, then output the new code.

<div style={{ display: "flex", justifyContent: "center", alignItems: "center" }}>
    <img src="/assets/multi_llm_step.png" alt="Multiple llm steps being chained" style={{ height: "400px" }} />
</div>

A typical improvement involves tweaking a small part of our pipeline (like improving our planning algorithm), then running the entire pipeline again. 
We use pdb (Python's built-in debugger) to set breakpoints and inspect the state of our prompts, input values, and parsing logic.
For example, we can check whether a certain string matches a regex:
```shell
(Pdb) p todays_date
'2024-03-14'
(Pdb) re.match(r"^\d{4}-\d{2}-\d{2}$", todays_date)
<re.Match object; span=(0, 10), match='2024-03-14'>
```
This lets us debug at runtime with the actual data.

## Cached pdb
pdb works great, but we have to wait for the entire pipeline to run again.
Imagine a version of pdb that not only interrupted execution, but also cached the entire program state up to that point.
Our time to hit that same bug could be cut from 10 minutes to 15 seconds (a 40x improvement).

We didn't build this, but we think our file_cache works just as well.

LLM calls are slow but their inputs and outputs are easy to cache, saving a lot of time. We can use the input prompt/string as the cache key, and the output string as the cache value.

### What's different from lru_cache?

`lru_cache` is great for memoizing repeated function calls, but it lacks two features we need.

1. We need to persist the cache between runs.
    - `lru_cache` stores results in memory, so the cache is empty the next time you run the program. `file_cache` stores results on disk. We also considered Redis, but writing to disk is easier to set up and manage.

2. `lru_cache` can't ignore arguments that would needlessly invalidate the cache.
    - We use a custom `chat_logger`, which stores the chats for visualization. It contains a timestamp, `chat_logger.expiration`, that constantly changes, so including it in the cache key would cause a miss on every call.
    - To counteract this, we added ignored parameters, used like this: `file_cache(ignore_params=["chat_logger"])`. This removes `chat_logger` from the cache key construction and prevents bad invalidation due to the constantly changing `expiration`.
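
To make the second point concrete, here's a rough, self-contained sketch of how a changing timestamp breaks the key and how ignoring the parameter fixes it. `ChatLogger` and `make_key` are hypothetical stand-ins written for this illustration, not the real classes or the real key construction.

```python
import hashlib


class ChatLogger:
    """Hypothetical stand-in for the chat logger; only the timestamp matters here."""

    def __init__(self, expiration):
        self.expiration = expiration


def make_key(params, ignore_params=()):
    """Hypothetical helper: digest of parameter names and values, minus ignored names."""
    parts = []
    for name, value in params.items():
        if name in ignore_params:
            continue
        # Objects are described by their attributes, so a new timestamp changes the key.
        described = vars(value) if hasattr(value, "__dict__") else value
        parts.append(f"{name}={described!r}")
    return hashlib.md5("|".join(parts).encode()).hexdigest()


run_1 = {"query": "fix the bug", "chat_logger": ChatLogger(expiration=1710720000)}
run_2 = {"query": "fix the bug", "chat_logger": ChatLogger(expiration=1710723600)}

# Same query, different timestamps: the keys differ unless chat_logger is ignored.
keys_without_ignore_match = make_key(run_1) == make_key(run_2)
keys_with_ignore_match = make_key(run_1, ["chat_logger"]) == make_key(run_2, ["chat_logger"])
```

With the logger included, two otherwise identical calls never share a key; with it ignored, they do.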

## Implementation

Our two main methods are `recursive_hash` and `file_cache`.

### recursive_hash
We want to stably hash objects, and this is [not natively supported in Python](https://death.andgravity.com/stable-hashing).

```python
import hashlib
from file_cache import recursive_hash


class Obj:
    def __init__(self, name):
        self.name = name

obj = Obj("test")
print(recursive_hash(obj)) # -> this works fine
try:
    hashlib.md5(obj).hexdigest()
except Exception as e:
    print(e) # -> this doesn't work
```

`hashlib.md5` alone doesn't work for objects, giving us the error: `TypeError: object supporting the buffer API required`.
Instead we use `recursive_hash`, which handles primitives, lists, tuples, dicts, and any object with a `__dict__`.

```python /recursive_hash(item, depth + 1, ignore_params)/ /recursive_hash(key, depth + 1, ignore_params)/ /+ recursive_hash(val, depth + 1, ignore_params)/ /recursive_hash(value.__dict__, depth + 1, ignore_params)/
def recursive_hash(value, depth=0, ignore_params=[]):
    """Hash primitives recursively with maximum depth."""
    if depth > MAX_DEPTH:
        return hashlib.md5("max_depth_reached".encode()).hexdigest()

    if isinstance(value, (int, float, str, bool, bytes)):
        return hashlib.md5(str(value).encode()).hexdigest()
    elif isinstance(value, (list, tuple)):
        return hashlib.md5(
            "".join(
                [recursive_hash(item, depth + 1, ignore_params) for item in value]
            ).encode()
        ).hexdigest()
    elif isinstance(value, dict):
        return hashlib.md5(
            "".join(
                [
                    recursive_hash(key, depth + 1, ignore_params)
                    + recursive_hash(val, depth + 1, ignore_params)
                    for key, val in value.items()
                    if key not in ignore_params
                ]
            ).encode()
        ).hexdigest()
    elif hasattr(value, "__dict__") and value.__class__.__name__ not in ignore_params:
        return recursive_hash(value.__dict__, depth + 1, ignore_params)
    else:
        return hashlib.md5("unknown".encode()).hexdigest()
``` 
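
The effect of the `__dict__` branch can be seen in isolation with a small sketch. `attrs_digest` below is a hypothetical, simplified stand-in for that one branch, not the function above. It contrasts Python's default identity-based `hash()` with a digest of the object's attributes.

```python
import hashlib


class Obj:
    def __init__(self, name):
        self.name = name


def attrs_digest(obj) -> str:
    """Hypothetical, simplified stand-in for the __dict__ branch of recursive_hash."""
    items = "".join(f"{key}{value}" for key, value in vars(obj).items())
    return hashlib.md5(items.encode()).hexdigest()


a, b = Obj("test"), Obj("test")

# The default hash() of a plain object is derived from its identity (on CPython),
# so two distinct objects with equal attributes hash differently.
identity_hashes_match = hash(a) == hash(b)

# A digest of the attributes depends only on the content.
content_digests_match = attrs_digest(a) == attrs_digest(b)
```

A content-based digest is what lets a cache key survive across runs, where every object is a fresh instance.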

### file_cache
file_cache is a decorator that handles the caching logic for us.

```python
@file_cache()
def search_codebase(
    cloned_github_repo,
    query,
):
    # ... take a long time ...
    # ... llm agent logic to search through the codebase ...
    return top_results
```

Without the cache, searching through the codebase using our LLM agent to get `top_results` takes 5 minutes, which is far too long when the search isn't the part we're testing. With `file_cache`, we only wait for the pickled result to be deserialized, which is nearly instantaneous for search results.

#### Wrapper

First we store our cache in `/tmp/file_cache`. This lets us remove the cache by simply deleting the directory (running `rm -rf /tmp/file_cache`).
We can also remove the cached calls of a single function. Cache file names start with the module name, so the glob needs a leading wildcard: `rm -rf /tmp/file_cache/*_search_codebase_*`.
```python
def wrapper(*args, **kwargs):
    cache_dir = "/tmp/file_cache"
    os.makedirs(cache_dir, exist_ok=True)
```

Then we can create a cache key.

#### Cache Key Creation / Miss Conditions
We have another problem - we want to miss our cache under two conditions:

1. The arguments to the function change - handled by `recursive_hash`
2. The code changes

To handle the second condition, we use `inspect.getsource(func)` to add the function's source code to the hash, correctly missing the cache when the code changes.

```python /hash_code(inspect.getsource(func))/
    func_source_code_hash = hash_code(inspect.getsource(func))
```
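
`hash_code` isn't shown in this post. A minimal sketch, assuming it is simply an MD5 digest of the source text, looks like this. The source strings are written inline here in place of what `inspect.getsource(func)` would return.

```python
import hashlib


def hash_code(code: str) -> str:
    # Assumed implementation: digest of the function's source text.
    return hashlib.md5(code.encode()).hexdigest()


# Two versions of the same function, as source text.
source_v1 = "def add(x, y):\n    return x + y\n"
source_v2 = "def add(x, y):\n    return x + y + 1\n"

same_source_same_hash = hash_code(source_v1) == hash_code(source_v1)
edited_source_new_hash = hash_code(source_v1) != hash_code(source_v2)
```

Any edit to the text, including a comment, produces a new hash and therefore a cache miss. Since `inspect.getsource(func)` returns only the text of `func` itself, edits to the functions it calls would not change this hash; deleting the cache directory covers that case.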

```python /func_source_code_hash/
    args_names = func.__code__.co_varnames[: func.__code__.co_argcount]
    args_dict = dict(zip(args_names, args))

    # Remove ignored params
    kwargs_clone = kwargs.copy()
    for param in ignore_params:
        args_dict.pop(param, None)
        kwargs_clone.pop(param, None)

    # Create hash based on argument names, argument values, and function source code
    arg_hash = (
        recursive_hash(args_dict, ignore_params=ignore_params)
        + recursive_hash(kwargs_clone, ignore_params=ignore_params)
        + func_source_code_hash
    )
    cache_file = os.path.join(
        cache_dir, f"{func.__module__}_{func.__name__}_{arg_hash}.pickle"
    )
```
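
The `co_varnames` slice is what maps positional arguments back to parameter names. Here's a small sketch of just that step; the function body and argument values are made up for illustration.

```python
def search_codebase(cloned_github_repo, query, chat_logger=None):
    top_results = []  # a local; it appears in co_varnames after the parameters
    return top_results


code = search_codebase.__code__

# co_varnames lists parameters first, then locals; co_argcount cuts off the locals.
args_names = code.co_varnames[: code.co_argcount]

# Positional arguments as they would arrive in *args.
args = ("sweepai/sweep", "where is the cache?")
args_dict = dict(zip(args_names, args))

# Popping an ignored name is safe even when it was not passed positionally.
args_dict.pop("chat_logger", None)
```

Here `args_names` is `('cloned_github_repo', 'query', 'chat_logger')`, and `zip` stops at the shorter sequence, so `args_dict` only contains the two arguments that were actually passed.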

#### Cache hits and misses
Finally, we check whether the cache file exists. On a hit we load and return the pickled result; on a miss we call the function and write its result to the cache.

```python
    try:
        # If cache exists, load and return it
        if os.path.exists(cache_file):
            if verbose:
                print("Used cache for function: " + func.__name__)
            with open(cache_file, "rb") as f:
                return pickle.load(f)
    except Exception:
        logger.info("Unpickling failed")

    # Otherwise, call the function and save its result to the cache
    result = func(*args, **kwargs)
    try:
        with open(cache_file, "wb") as f:
            pickle.dump(result, f)
    except Exception as e:
        logger.info(f"Pickling failed: {e}")
    return result
```
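
To try the load-or-compute pattern without the full module, here's a stripped-down, self-contained sketch. `tiny_file_cache` is a hypothetical name, and its key is just a digest of the arguments' `repr`, which is a simplification of the key construction described above (no source hash, no ignored parameters).

```python
import functools
import hashlib
import os
import pickle
import tempfile


def tiny_file_cache(cache_dir):
    """Hypothetical, stripped-down decorator: pickles results keyed by the arguments."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            os.makedirs(cache_dir, exist_ok=True)
            raw_key = repr((args, sorted(kwargs.items())))
            key = hashlib.md5(raw_key.encode()).hexdigest()
            cache_file = os.path.join(cache_dir, f"{func.__name__}_{key}.pickle")

            # Cache hit: load the pickled result instead of calling the function.
            if os.path.exists(cache_file):
                with open(cache_file, "rb") as f:
                    return pickle.load(f)

            # Cache miss: compute, then persist for next time.
            result = func(*args, **kwargs)
            with open(cache_file, "wb") as f:
                pickle.dump(result, f)
            return result

        return wrapper

    return decorator


calls = []  # records every real invocation

with tempfile.TemporaryDirectory() as tmp:

    @tiny_file_cache(tmp)
    def slow_add(x, y):
        calls.append((x, y))
        return x + y

    first = slow_add(1, 2)   # miss: runs the function
    second = slow_add(1, 2)  # hit: loaded from disk
    third = slow_add(2, 2)   # different arguments: miss
```

After this runs, `calls` holds two entries, one per distinct set of arguments.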

## Conclusion

We hope this code is useful to you. We've found it to be a massive time saver when debugging LLM calls.
We'd love to hear your feedback and contributions at https://github.com/sweepai/sweep!
